#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# -*- mode: python-mode -*-
#
"""
smbwalk is a script to index all browsable files available through SMB shares. From one (or
many) IP addresses, it crawls through the shares available to the specified user (default:
guest), skipping the administrative shares ADMIN$, IPC$, C$ and D$. On top of enumerating the
available files, it fingerprints each file's type using libmagic.

Inputs can be provided either directly by specifying the IP addresses on the command line:
Ex:
$ ./smbwalk.py 10.0.0.1 10.0.0.2

Or by pointing to a SQLite database generated by the script `unmap.py`.
Ex:
$ ./smbwalk.py --sql /path/to/unmap.sqlite

By default, the script prints all the entries to stdout, which is not practical for
big networks; it can instead store all the paths found in a SQLite database (`.sqlite`
extension) or in a text file (`.txt` extension).
Ex:
$ ./smbwalk.py --outfile /path/to/smbwalk.sqlite --sql /path/to/unmap.sqlite

Use `-h` to see all the different options.
"""

import sys
import argparse
import logging
import threading
import re
import sqlite3
import unmap
import tempfile
import os

try:
    import magic
except ImportError as ie:
    print("[-] Failed to import libmagic bindings: run `pip install python-magic`")
    sys.exit(1)

try:
    from impacket import smb, version, smb3, nt_errors
    from impacket.smbconnection import *
except ImportError as ie:
    print("[-] Failed to import impacket suite: run `pip install impacket`")
    sys.exit(1)


__author__ = "@_hugsy_"
__version__ = 0.1
__licence__ = "WTFPL v.2"
# NOTE: this overrides the module's real __file__; the value is reused below as the
# logger name and in the help texts (and as argparse `prog` in the __main__ section).
__file__ = "smbwalk.py"
__desc__ = "SMB walker : use `pydoc {:s}` for man page".format(__file__)
__usage__ = "{3} v{0}\nby {2} under {1}\nsyntax: {3} [options] args".format(
    __version__, __licence__, __author__, __file__
)

# Share names that scan_host() never walks.
BLACKLISTED_SHARES = ["ADMIN$", "IPC$", "D$", "C$"]

# Results appended by smbwalk() from every scanning thread:
# one [ip, share, path, magic] list per file.
q = []

# Single stderr handler; threadName is included because one thread runs per host.
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter(
        fmt="%(levelname)s %(asctime)s - %(name)s:%(threadName)s - %(message)s",
        datefmt="%d/%m/%Y-%H:%M:%S",
    )
)
logger = logging.getLogger(__file__)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
logger.propagate = False  # do not forward records to the root logger's handlers

def get_dialect(smb):
    """Return a human-readable label for the dialect negotiated on `smb`.

    SMB_DIALECT, SMB2_DIALECT_002 and SMB2_DIALECT_21 map to "SMBv1", "SMBv2.0" and
    "SMBv2.1"; any other value is reported as "SMBv3.0 (<raw dialect value>)".
    """
    negotiated = smb.getDialect()
    # Checked in order with ==, first match wins.
    known_dialects = (
        (SMB_DIALECT, "SMBv1"),
        (SMB2_DIALECT_002, "SMBv2.0"),
        (SMB2_DIALECT_21, "SMBv2.1"),
    )
    for candidate, label in known_dialects:
        if negotiated == candidate:
            return label
    return "SMBv3.0 (%s)" % negotiated


def safe_smbconnect(host, port=445):
    """Open an SMB connection to host:port without raising.

    `host` is passed both as the remote name and the remote address. When verbose,
    the negotiated dialect is logged on success and the failure reason on error.

    :param host: target host / IP address
    :param port: SMB session port (default: 445)
    :return: the SMBConnection, or None if it could not be established
    """
    try:
        smb = SMBConnection(host, host, sess_port=port)
        if verbose:
            logger.debug("Connected to {}:{} using {}".format(host, port, get_dialect(smb)))
    except Exception as e:
        # `Exception`, not a bare except: Ctrl-C / sys.exit must not be swallowed here.
        if verbose:
            logger.error("Failed to connect to {}:{}: {}".format(host, port, e))
        return None
    return smb


def safe_smblogin(smb, **kwargs):
    """Authenticate on an established SMB connection without raising.

    Keyword arguments read: `user`, `pwd`, `domain`, `lmhash`, `nthash`.
    When both `lmhash` and `nthash` are truthy, Pass-The-Hash is used with an empty
    password (`pwd` is ignored); otherwise a plain password login is attempted.

    :return: True if the login succeeded, False if anything raised.
    """
    user = kwargs.get("user")
    domain = kwargs.get("domain")
    lm_hash = kwargs.get("lmhash")
    nt_hash = kwargs.get("nthash")
    pass_the_hash = bool(lm_hash and nt_hash)

    try:
        if pass_the_hash:
            if verbose:
                logger.debug("Logging using Pass-The-Hash")
            smb.login(user, '', domain=domain, lmhash=lm_hash, nthash=nt_hash)
        else:
            if verbose:
                logger.debug("Logging using password")
            smb.login(user, kwargs.get("pwd"), domain=domain)

        if verbose:
            is_guest = smb.isGuestSession() > 0
            logger.debug("GUEST Session Granted" if is_guest else "USER Session Granted")
    except Exception as e:
        if verbose:
            logger.error("Failed to login: %s" % e)
        return False
    else:
        return True


def safe_enumshares(smb):
    """Return the names of the shares exposed by the server behind `smb`.

    The last character of every `shi1_netname` is dropped (presumably a trailing NUL
    terminator -- TODO confirm against impacket). Any error is logged when verbose
    and results in an empty list, so callers never receive None.
    """
    try:
        shares = smb.listShares()
        return [entry['shi1_netname'][:-1] for entry in shares]
    except Exception as e:
        if verbose:
            logger.error("Got exception while getting shares: %s" % e)
        return []


def smbwalk(smb, share, regex, path='\\', tid=None, *args, **kwargs):
    """Recursively walk `path` on `share`, recording matching files in the global `q`.

    Every non-directory entry whose full path matches `regex` (or every one when `regex`
    is None) is opened, its first `max_size` bytes are read and fingerprinted with
    libmagic, and [ip, share, path, magic] is appended to `q`; `magic` is "" when the
    file could not be read. Directories are always descended into, whatever `regex` is.

    :param smb: authenticated SMBConnection
    :param share: name of the share to walk
    :param regex: compiled pattern applied with .search() to the file path, or None
    :param path: directory to start from (default: the share root)
    :param tid: tree id to reuse; when None the share is connected here (this function
                never disconnects it)
    :param max_size: keyword option, maximum number of bytes read per file
                     (default 100 KiB); propagated to sub-directory walks
    :return: the tree id used, or None if the tree connection or the listing failed
    """
    max_size = kwargs.get("max_size") or 100*1024
    ip = smb.getRemoteHost()

    if tid is None:
        try:
            tid = smb.connectTree(share)
        except Exception as e:
            if verbose:
                logger.warning("Failed to connect to tree '{}': {}".format(share, e))
            return None

    path = ntpath.normpath(path)

    try:
        entries = smb.listPath(share, ntpath.join(path, '*'))
    except Exception as e:
        if verbose:
            logger.warning("Failed to list share '{}' at '{}': {}".format(share, path, e))
        return None

    for f in entries:
        name = f.get_longname()
        if name in (".", ".."):
            continue

        cur_path = ntpath.join(path, name)

        if f.is_directory():
            try:
                # Forward the caller's options (e.g. max_size) so they apply at every depth.
                smbwalk(smb, share, regex, cur_path + "\\", tid, *args, **kwargs)
            except Exception as e:
                if verbose:
                    logger.warning("Failed to list path '{}': {}".format(cur_path, e))
            continue

        if regex is not None and not regex.search(cur_path):
            continue

        entry = [ip, share, cur_path]
        try:
            fhandle = smb.openFile(tid, cur_path, desiredAccess=FILE_READ_DATA,
                                   shareMode=FILE_SHARE_READ)
            try:
                fdata = smb.readFile(tid, fhandle, offset=0, bytesToRead=max_size)
            finally:
                # Release the remote handle even when the read fails.
                smb.closeFile(tid, fhandle)
            entry.append(magic.from_buffer(fdata))
        except Exception as e:
            if verbose:
                logger.warning("Failed to retrieve file: {}".format(e))
            entry.append("")

        if verbose:
            logger.info("\\\\" + " -> ".join(entry))

        q.append(entry)

    return tid


def scan_host(host, port, **kwargs):
    """Connect to host:port, authenticate, and walk every non-blacklisted share.

    Keyword arguments read here: `regex` (optional pattern string, compiled
    case-insensitively and matched against file paths). All kwargs are also forwarded
    to safe_smblogin() for the credentials (`user`, `pwd`, `domain`, `lmhash`,
    `nthash`). Matching files are collected by smbwalk() into the global `q`.
    Once connected, the SMB connection is always closed before returning.
    """
    regex = kwargs.get("regex", None)
    if regex is not None:
        regex = re.compile(regex, re.I)

    smb = safe_smbconnect(host, port)
    if smb is None:
        return

    try:
        if not safe_smblogin(smb, **kwargs):
            return

        # safe_enumshares() returns [] (never None) on failure.
        sharenames = safe_enumshares(smb)

        if verbose:
            # `{1}` rather than `{1:s}`: a list rejects the 's' format spec on Python 3.
            logger.info("Found {0:d} shares: {1}".format(len(sharenames), sharenames))

        for share in sharenames:
            if share in BLACKLISTED_SHARES:
                continue
            smbwalk(smb, share, regex)

        try:
            smb.logoff()
        except Exception as e:
            if verbose:
                logger.warning("Failed to log off from {}:{}: {}".format(host, port, e))
    finally:
        # Release the underlying transport on every exit path, including login failure.
        try:
            smb.close()
        except Exception as e:
            if verbose:
                logger.warning("Failed to close connection to {}:{}: {}".format(host, port, e))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(usage=__usage__, description=__desc__, prog=__file__)

    parser.add_argument("-p", "--port", type=int, default=445,
                        help="Use any alternative SMB port (default: 445)")

    parser.add_argument("-U", "--username", type=str, default="guest",
                        help="Username to use to connect")

    parser.add_argument("-D", "--domain", type=str, default="",
                        help="Domain associated to username")

    parser.add_argument("-P", "--password", type=str, default="guest",
                        help="Password associated to username")

    parser.add_argument("-r", "--regex", type=str, default=None,
                        help="Grep matching regex")

    parser.add_argument("-s", "--sql", dest="db", type=str, metavar="db.sqlite",
                        default=None, help="Read IP addresses from unmap SQLite database")

    parser.add_argument("-o", "--outfile", type=str, metavar="/path/to/result",
                        default=None, help="Write results to file (.txt) or database (.sqlite)")

    parser.add_argument("iplist", type=str, metavar="ip1 [ip2]*", nargs='*',
                        help="Specify IP addresses to query")

    parser.add_argument("-v", "--verbose", action="count", dest="verbose",
                        help="Increments verbosity")

    parser.add_argument("-t", "--threads", dest="threads", type=int, metavar="N",
                        default=20, help="Specify number of threads to use")

    parser.add_argument("-V", "--version", action="version", version=__version__)

    args = parser.parse_args()
    # Module-level global read by the helper functions above.
    verbose = args.verbose

    if not args.iplist and not args.db:
        logger.error("Missing target(s) or database. Use `unmap.py -t sql`")
        sys.exit(1)

    # Validate --outfile and --regex before scanning: a scan can take a long time, and
    # these errors would otherwise only surface once it is over (results discarded)
    # or as a traceback in every worker thread.
    if args.outfile and not args.outfile.endswith((".sqlite", ".txt")):
        logger.error("Unknown extension for '{}': expected .sqlite or .txt".format(args.outfile))
        sys.exit(1)

    if args.regex is not None:
        try:
            re.compile(args.regex, re.I)
        except re.error as e:
            logger.error("Invalid regex '{}': {}".format(args.regex, e))
            sys.exit(1)

    if args.db:
        if verbose:
            logger.debug("Reading IP from '{}'".format(args.db))
        db = unmap.Database(args.db)
        iplist = db.get_ip_by_ports(args.port)
    else:
        iplist = args.iplist

    # A 65-character value containing ':' is split into "LM:NT" hashes (Pass-The-Hash).
    if ":" in args.password and len(args.password) == 65:
        password = None
        lmhash, nthash = args.password.split(":", 1)
    else:
        password = args.password
        lmhash, nthash = None, None

    if verbose:
        logger.debug("Targets: {}".format(iplist))
        if lmhash and nthash:
            logger.debug("Credentials: username={} hash={}:{}".format(args.username, lmhash, nthash))
        else:
            logger.debug("Credentials: username={} password={}".format(args.username, password))

    scan_options = {"user": args.username,
                    "pwd": password,
                    "domain": args.domain,
                    "regex": args.regex,
                    "lmhash": lmhash,
                    "nthash": nthash, }

    # Hosts are scanned in batches of args.threads; each batch is fully joined before
    # the next one starts.
    threads = []
    for ip in iplist:
        t = threading.Thread(target=scan_host, args=(ip, args.port,), kwargs=scan_options)
        t.daemon = True
        t.start()
        threads.append(t)

        if len(threads) == args.threads:
            for running in threads:
                running.join()
            threads = []

    for running in threads:
        running.join()

    if not args.outfile:
        for line in q:
            print("\t".join(line))
        sys.exit(0)

    if args.outfile.endswith(".sqlite"):
        conn = sqlite3.connect(args.outfile)
        try:
            conn.execute("CREATE TABLE IF NOT EXISTS files (ip VARCHAR(55), share VARCHAR(255), filepath VARCHAR(255), magic VARCHAR(255))")
            conn.executemany("INSERT INTO files(ip, share, filepath, magic) VALUES (?,?,?,?)", q)
            conn.commit()
        finally:
            conn.close()

        if verbose:
            logger.info("Written results in database '{}'".format(args.outfile))

    elif args.outfile.endswith(".txt"):
        with open(args.outfile, "w") as f:
            f.writelines(';'.join(e) + '\n' for e in q)

        if verbose:
            logger.info("Written results in text file '{}'".format(args.outfile))

    sys.exit(0)
